"""Support for MQTT message handling."""
from __future__ import annotations

import asyncio
from collections.abc import Callable
import logging
from typing import Any, cast

import jinja2
import voluptuous as vol

from homeassistant import config as conf_util, config_entries
from homeassistant.components import websocket_api
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
    CONF_DISCOVERY,
    CONF_PASSWORD,
    CONF_PAYLOAD,
    CONF_PORT,
    CONF_USERNAME,
    SERVICE_RELOAD,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.helpers import (
    config_validation as cv,
    discovery_flow,
    event,
    template,
)
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.helpers.reload import (
    async_integration_yaml_config,
    async_reload_integration_platforms,
)
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType

# Loading the config flow file will register the flow
from . import debug_info, discovery
from .client import (  # noqa: F401
    MQTT,
    async_publish,
    async_subscribe,
    publish,
    subscribe,
)
from .config_integration import (
    CONFIG_SCHEMA_BASE,
    DEFAULT_VALUES,
    DEPRECATED_CONFIG_KEYS,
)
from .const import (  # noqa: F401
    ATTR_PAYLOAD,
    ATTR_QOS,
    ATTR_RETAIN,
    ATTR_TOPIC,
    CONF_BIRTH_MESSAGE,
    CONF_BROKER,
    CONF_COMMAND_TOPIC,
    CONF_DISCOVERY_PREFIX,
    CONF_QOS,
    CONF_STATE_TOPIC,
    CONF_TLS_VERSION,
    CONF_TOPIC,
    CONF_WILL_MESSAGE,
    DATA_MQTT,
    DEFAULT_ENCODING,
    DEFAULT_QOS,
    DEFAULT_RETAIN,
    DOMAIN,
    MQTT_CONNECTED,
    MQTT_DISCONNECTED,
    PLATFORMS,
    RELOADABLE_PLATFORMS,
)
from .models import (  # noqa: F401
    MqttCommandTemplate,
    MqttValueTemplate,
    PublishPayloadType,
    ReceiveMessage,
    ReceivePayloadType,
)
from .util import (
    _VALID_QOS_SCHEMA,
    get_mqtt_data,
    mqtt_config_entry_enabled,
    valid_publish_topic,
    valid_subscribe_topic,
)

_LOGGER = logging.getLogger(__name__)

# Service names registered under the MQTT domain by async_setup_entry
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"

# Basic options that _merge_basic_config fills from DEFAULT_VALUES when the
# config entry data lacks them
MANDATORY_DEFAULT_VALUES = (CONF_PORT,)

# Extra keys accepted by the publish service (see MQTT_PUBLISH_SCHEMA)
ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"

# NOTE(review): not used within this file; presumably kept for other modules
# of the integration — verify before removing.
MAX_RECONNECT_WAIT = 300  # seconds

# NOTE(review): not used within this file; presumably consumed elsewhere.
CONNECTION_SUCCESS = "connection_success"
CONNECTION_FAILED = "connection_failed"
CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable"

# The only keys allowed to remain in the config entry data;
# _filter_entry_config removes every other key (with a warning).
CONFIG_ENTRY_CONFIG_KEYS = [
    CONF_BIRTH_MESSAGE,
    CONF_BROKER,
    CONF_DISCOVERY,
    CONF_PASSWORD,
    CONF_PORT,
    CONF_USERNAME,
    CONF_WILL_MESSAGE,
]

# configuration.yaml schema for the `mqtt:` key. The listed options are
# flagged as deprecated in YAML; the remainder is validated by
# CONFIG_SCHEMA_BASE. Keys of other integrations are allowed through.
CONFIG_SCHEMA = vol.Schema(
    {
        DOMAIN: vol.All(
            cv.deprecated(CONF_BIRTH_MESSAGE),  # Deprecated in HA Core 2022.3
            cv.deprecated(CONF_BROKER),  # Deprecated in HA Core 2022.3
            cv.deprecated(CONF_DISCOVERY),  # Deprecated in HA Core 2022.3
            cv.deprecated(CONF_PASSWORD),  # Deprecated in HA Core 2022.3
            cv.deprecated(CONF_PORT),  # Deprecated in HA Core 2022.3
            cv.deprecated(CONF_TLS_VERSION),  # Deprecated June 2020
            cv.deprecated(CONF_USERNAME),  # Deprecated in HA Core 2022.3
            cv.deprecated(CONF_WILL_MESSAGE),  # Deprecated in HA Core 2022.3
            CONFIG_SCHEMA_BASE,
        )
    },
    extra=vol.ALLOW_EXTRA,
)


# Service call validation schema.
# `topic` and `topic_template` are mutually exclusive (shared Exclusive group),
# as are `payload` and `payload_template`; at least one of the topic keys must
# be supplied. qos and retain fall back to DEFAULT_QOS / DEFAULT_RETAIN.
MQTT_PUBLISH_SCHEMA = vol.All(
    vol.Schema(
        {
            vol.Exclusive(ATTR_TOPIC, CONF_TOPIC): valid_publish_topic,
            vol.Exclusive(ATTR_TOPIC_TEMPLATE, CONF_TOPIC): cv.string,
            vol.Exclusive(ATTR_PAYLOAD, CONF_PAYLOAD): cv.string,
            vol.Exclusive(ATTR_PAYLOAD_TEMPLATE, CONF_PAYLOAD): cv.string,
            vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA,
            vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean,
        },
        required=True,
    ),
    cv.has_at_least_one_key(ATTR_TOPIC, ATTR_TOPIC_TEMPLATE),
)


async def _async_setup_discovery(
    hass: HomeAssistant, conf: ConfigType, config_entry
) -> None:
    """Start MQTT device discovery under the configured discovery prefix.

    This method is a coroutine.
    """
    prefix = conf[CONF_DISCOVERY_PREFIX]
    await discovery.async_start(hass, prefix, config_entry)


async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Start the MQTT protocol service.

    Registers the MQTT websocket commands, stores any `mqtt:` YAML config and
    flags that manually configured items must be reloaded once a usable
    config entry is set up.
    """
    mqtt_data = get_mqtt_data(hass, True)
    yaml_conf: ConfigType | None = config.get(DOMAIN)

    for ws_command in (websocket_subscribe, websocket_mqtt_info):
        websocket_api.async_register_command(hass, ws_command)

    if yaml_conf:
        # Keep a plain, mutable copy of the YAML configuration
        mqtt_data.config = dict(yaml_conf)

    entry_enabled = mqtt_config_entry_enabled(hass)
    if entry_enabled is None:
        # Create an import flow if the user has yaml configured entities etc.
        # but no broker configuration. Note: The intention is not for this to
        # import broker configuration from YAML because that has been deprecated.
        discovery_flow.async_create_flow(
            hass,
            DOMAIN,
            context={"source": config_entries.SOURCE_INTEGRATION_DISCOVERY},
            data={},
        )
        mqtt_data.reload_needed = True
    elif entry_enabled is False:
        _LOGGER.info(
            "MQTT will be not available until the config entry is enabled",
        )
        mqtt_data.reload_needed = True

    return True


def _filter_entry_config(hass: HomeAssistant, entry: ConfigEntry) -> None:
    """Strip config entry data down to the keys in CONFIG_ENTRY_CONFIG_KEYS.

    Older YAML imports may have copied unsupported options into the entry;
    those are dropped (with a warning) and the entry is updated in place.
    """
    data = entry.data
    kept = {key: data[key] for key in CONFIG_ENTRY_CONFIG_KEYS if key in data}
    dropped = data.keys() - kept.keys()
    if not dropped:
        return

    _LOGGER.warning(
        "The following unsupported configuration options were removed from the "
        "MQTT config entry: %s. Add them to configuration.yaml if they are needed",
        dropped,
    )
    hass.config_entries.async_update_entry(entry, data=kept)


def _merge_basic_config(
    hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
    """Complete the config entry's basic options from YAML and defaults.

    Deprecated basic options present in configuration.yaml but missing from
    the entry are copied over, then any still-missing mandatory option gets
    its default. This mends incomplete migration from old versions of HA Core.
    The entry is only written back when something was added.
    """
    merged = dict(entry.data)
    changed = False

    for key in DEPRECATED_CONFIG_KEYS:
        if key not in yaml_config or key in merged:
            continue
        merged[key] = yaml_config[key]
        changed = True

    missing_mandatory = [key for key in MANDATORY_DEFAULT_VALUES if key not in merged]
    for key in missing_mandatory:
        merged[key] = DEFAULT_VALUES[key]

    if changed or missing_mandatory:
        hass.config_entries.async_update_entry(entry, data=merged)


def _merge_extended_config(entry, conf):
    """Return defaults overlaid with YAML options, overlaid with entry data.

    Precedence, lowest to highest: DEFAULT_VALUES, ``conf`` (YAML options),
    ``entry.data``.
    """
    return {**DEFAULT_VALUES, **conf, **entry.data}


async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
    """Reload the MQTT config entry after it was updated.

    Registered as the entry's update listener, e.g. when its options change.
    """
    entry_id = entry.entry_id
    await hass.config_entries.async_reload(entry_id)


async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None:
    """Build the effective MQTT config for setting up (or reloading) the entry.

    Returns the merged configuration, or None when the entry has no broker.
    """
    mqtt_data = get_mqtt_data(hass)
    if mqtt_data.reload_entry:
        # On reload, re-read configuration.yaml so YAML changes are picked up
        fresh_yaml = await conf_util.async_hass_config_yaml(hass)
        mqtt_data.config = CONFIG_SCHEMA_BASE(fresh_yaml.get(DOMAIN, {}))

    # Drop unknown keys, then fill in basic options from YAML / defaults
    _filter_entry_config(hass, entry)
    _merge_basic_config(hass, entry, mqtt_data.config or {})

    if CONF_BROKER not in entry.data:
        _LOGGER.error("MQTT broker is not configured, please configure it")
        return None

    yaml_conf = mqtt_data.config
    if yaml_conf is None:
        # No configuration.yaml config: derive defaults from the entry data
        yaml_conf = CONFIG_SCHEMA_BASE(dict(entry.data))
    else:
        # Warn about YAML options that the config entry silently overrides
        common_keys = yaml_conf.keys() & entry.data.keys()
        if common_keys:
            overridden = {
                key: entry.data[key]
                for key in common_keys
                if yaml_conf[key] != entry.data[key]
            }
            if CONF_PASSWORD in overridden:
                overridden[CONF_PASSWORD] = "********"
            if overridden:
                _LOGGER.warning(
                    "Deprecated configuration settings found in configuration.yaml. "
                    "These settings from your configuration entry will override: %s",
                    overridden,
                )

    return _merge_extended_config(entry, yaml_conf)


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Load a config entry.

    Merges the entry data with any configuration.yaml options, creates and
    connects the MQTT client, registers the ``publish`` and ``dump`` services
    (and the ``reload`` service if it is not registered yet), forwards the
    entry to the MQTT platforms and starts discovery when it is enabled.

    Returns False, failing the setup, when no broker is configured.
    """
    mqtt_data = get_mqtt_data(hass, True)

    # Merge basic configuration, and add missing defaults for basic options
    if (conf := await async_fetch_config(hass, entry)) is None:
        # Bail out: no broker is configured
        return False
    mqtt_data.client = MQTT(hass, entry, conf)
    # Restore subscriptions that were saved when the entry was last unloaded
    if mqtt_data.subscriptions_to_restore:
        mqtt_data.client.subscriptions = mqtt_data.subscriptions_to_restore
        mqtt_data.subscriptions_to_restore = []
    # Reload the entry when it is updated; the returned remover is stored so
    # the listener can be detached when the entry is unloaded
    mqtt_data.reload_dispatchers.append(
        entry.add_update_listener(_async_config_entry_updated)
    )

    await mqtt_data.client.async_connect()

    async def async_publish_service(call: ServiceCall) -> None:
        """Handle MQTT publish service calls.

        Renders ``topic_template`` / ``payload_template`` when supplied. A
        rendering error or an invalid rendered topic is logged and nothing
        is published.
        """
        msg_topic = call.data.get(ATTR_TOPIC)
        msg_topic_template = call.data.get(ATTR_TOPIC_TEMPLATE)
        payload = call.data.get(ATTR_PAYLOAD)
        payload_template = call.data.get(ATTR_PAYLOAD_TEMPLATE)
        qos: int = call.data[ATTR_QOS]
        retain: bool = call.data[ATTR_RETAIN]
        if msg_topic_template is not None:
            try:
                rendered_topic = template.Template(
                    msg_topic_template, hass
                ).async_render(parse_result=False)
                msg_topic = valid_publish_topic(rendered_topic)
            except (jinja2.TemplateError, TemplateError) as exc:
                _LOGGER.error(
                    "Unable to publish: rendering topic template of %s "
                    "failed because %s",
                    msg_topic_template,
                    exc,
                )
                return
            except vol.Invalid as err:
                # Only valid_publish_topic raises this, so rendered_topic is bound
                _LOGGER.error(
                    "Unable to publish: topic template '%s' produced an "
                    "invalid topic '%s' after rendering (%s)",
                    msg_topic_template,
                    rendered_topic,
                    err,
                )
                return

        if payload_template is not None:
            try:
                payload = MqttCommandTemplate(
                    template.Template(payload_template), hass=hass
                ).async_render()
            except (jinja2.TemplateError, TemplateError) as exc:
                _LOGGER.error(
                    "Unable to publish to %s: rendering payload template of "
                    "%s failed because %s",
                    msg_topic,
                    payload_template,
                    exc,
                )
                return

        # MQTT_PUBLISH_SCHEMA requires a topic or a topic template, so a topic
        # is always set at this point
        assert mqtt_data.client is not None and msg_topic is not None
        await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)

    hass.services.async_register(
        DOMAIN, SERVICE_PUBLISH, async_publish_service, schema=MQTT_PUBLISH_SCHEMA
    )

    async def async_dump_service(call: ServiceCall) -> None:
        """Handle MQTT dump service calls.

        Collects messages received on ``topic`` for ``duration`` seconds and
        then writes them as ``topic,payload`` lines to ``mqtt_dump.txt`` in
        the configuration directory.
        """
        messages = []

        @callback
        def collect_msg(msg):
            # Newlines are stripped so each message stays on one line
            messages.append((msg.topic, msg.payload.replace("\n", "")))

        unsub = await async_subscribe(hass, call.data["topic"], collect_msg)

        def write_dump():
            with open(hass.config.path("mqtt_dump.txt"), "wt", encoding="utf8") as fp:
                for msg in messages:
                    fp.write(",".join(msg) + "\n")

        async def finish_dump(_):
            """Write dump to file."""
            unsub()
            # File I/O is blocking, so it runs in the executor
            await hass.async_add_executor_job(write_dump)

        event.async_call_later(hass, call.data["duration"], finish_dump)

    hass.services.async_register(
        DOMAIN,
        SERVICE_DUMP,
        async_dump_service,
        schema=vol.Schema(
            {
                vol.Required("topic"): valid_subscribe_topic,
                vol.Optional("duration", default=5): int,
            }
        ),
    )

    # setup platforms and discovery

    async def async_setup_reload_service() -> None:
        """Create the reload service for the MQTT domain."""
        if hass.services.has_service(DOMAIN, SERVICE_RELOAD):
            return

        async def _reload_config(call: ServiceCall) -> None:
            """Reload the platforms."""
            # Reload the legacy yaml platform
            await async_reload_integration_platforms(hass, DOMAIN, RELOADABLE_PLATFORMS)

            # Reload the modern yaml platforms: remove every manually configured
            # (non-discovered) entity of a reloadable platform first
            mqtt_platforms = async_get_platforms(hass, DOMAIN)
            tasks = [
                entity.async_remove()
                for mqtt_platform in mqtt_platforms
                for entity in mqtt_platform.entities.values()
                if not entity._discovery_data  # type: ignore[attr-defined] # pylint: disable=protected-access
                if mqtt_platform.config_entry
                and mqtt_platform.domain in RELOADABLE_PLATFORMS
            ]
            await asyncio.gather(*tasks)

            # Re-read the yaml config and let each platform set up its items again
            config_yaml = await async_integration_yaml_config(hass, DOMAIN) or {}
            mqtt_data.updated_config = config_yaml.get(DOMAIN, {})
            await asyncio.gather(
                *(
                    [
                        mqtt_data.reload_handlers[component]()
                        for component in RELOADABLE_PLATFORMS
                        if component in mqtt_data.reload_handlers
                    ]
                )
            )

            # Fire event
            hass.bus.async_fire(f"event_{DOMAIN}_reloaded", context=call.context)

        async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _reload_config)

    async def async_forward_entry_setup_and_setup_discovery(
        config_entry: ConfigEntry,
    ) -> None:
        """Forward the config entry setup to the platforms and set up discovery."""
        reload_manual_setup: bool = False
        # Local import to avoid circular dependencies
        # pylint: disable-next=import-outside-toplevel
        from . import device_automation, tag

        # Forward the entry setup to the MQTT platforms
        await asyncio.gather(
            *(
                [
                    device_automation.async_setup_entry(hass, config_entry),
                    tag.async_setup_entry(hass, config_entry),
                ]
                + [
                    hass.config_entries.async_forward_entry_setup(
                        config_entry, component
                    )
                    for component in PLATFORMS
                ]
            )
        )
        # Setup discovery
        if conf.get(CONF_DISCOVERY):
            await _async_setup_discovery(hass, conf, config_entry)
        # Setup reload service after all platforms have loaded
        await async_setup_reload_service()
        # When the entry is reloaded, also reload manual set up items to enable MQTT
        if mqtt_data.reload_entry:
            mqtt_data.reload_entry = False
            reload_manual_setup = True

        # When the entry was disabled before, reload manual set up items to enable MQTT again
        if mqtt_data.reload_needed:
            mqtt_data.reload_needed = False
            reload_manual_setup = True

        if reload_manual_setup:
            await async_reload_manual_mqtt_items(hass)

    await async_forward_entry_setup_and_setup_discovery(entry)

    return True


async def async_reload_manual_mqtt_items(hass: HomeAssistant) -> None:
    """Reload manually configured MQTT items via the MQTT reload service.

    The call blocks until the reload service has finished.
    """
    no_service_data: dict = {}
    await hass.services.async_call(
        DOMAIN, SERVICE_RELOAD, no_service_data, blocking=True
    )


@websocket_api.websocket_command(
    {vol.Required("type"): "mqtt/device/debug_info", vol.Required("device_id"): str}
)
@callback
def websocket_mqtt_info(hass, connection, msg):
    """Reply with the MQTT debug information collected for a device."""
    device_debug_info = debug_info.info_for_device(hass, msg["device_id"])
    connection.send_result(msg["id"], device_debug_info)


@websocket_api.websocket_command(
    {
        vol.Required("type"): "mqtt/subscribe",
        vol.Required("topic"): valid_subscribe_topic,
    }
)
@websocket_api.async_response
async def websocket_subscribe(hass, connection, msg):
    """Subscribe a websocket client (admin only) to a MQTT topic.

    Every received message is relayed as an event carrying topic, payload,
    qos and retain.
    """
    if not connection.user.is_admin:
        raise Unauthorized

    request_id = msg["id"]

    async def forward_messages(mqttmsg: ReceiveMessage):
        """Relay one received MQTT message to the websocket client."""
        raw_payload = mqttmsg.payload
        try:
            # The subscription uses encoding=None, so the payload is undecoded
            text_payload = cast(bytes, raw_payload).decode(DEFAULT_ENCODING)
        except (AttributeError, UnicodeDecodeError):
            # Fall back to a string representation of non UTF-8 payloads
            text_payload = str(raw_payload)

        event_data = {
            "topic": mqttmsg.topic,
            "payload": text_payload,
            "qos": mqttmsg.qos,
            "retain": mqttmsg.retain,
        }
        connection.send_message(websocket_api.event_message(request_id, event_data))

    # Decoding is done in forward_messages, hence encoding=None
    connection.subscriptions[request_id] = await async_subscribe(
        hass, msg["topic"], forward_messages, encoding=None
    )

    connection.send_message(websocket_api.result_message(request_id))


# Callback type accepted by async_subscribe_connection_status; it is invoked
# with True when MQTT connects and False when it disconnects.
ConnectionStatusCallback = Callable[[bool], None]


@callback
def async_subscribe_connection_status(
    hass: HomeAssistant, connection_status_callback: ConnectionStatusCallback
) -> Callable[[], None]:
    """Subscribe to MQTT connection changes.

    The callback receives True on connect and False on disconnect. Returns a
    function that removes both dispatcher subscriptions.
    """
    status_job = HassJob(connection_status_callback)

    async def _notify(status: bool) -> None:
        # Await the job only when running it produced an awaitable
        if pending := hass.async_run_hass_job(status_job, status):
            await pending

    async def connected():
        await _notify(True)

    async def disconnected():
        await _notify(False)

    unsubscribers = (
        async_dispatcher_connect(hass, MQTT_CONNECTED, connected),
        async_dispatcher_connect(hass, MQTT_DISCONNECTED, disconnected),
    )

    @callback
    def unsubscribe():
        for remove_listener in unsubscribers:
            remove_listener()

    return unsubscribe


def is_connected(hass: HomeAssistant) -> bool:
    """Report whether the MQTT client is currently connected.

    Expects the MQTT client to have been created (asserted).
    """
    client = get_mqtt_data(hass).client
    assert client is not None
    return client.connected


async def async_remove_config_entry_device(
    hass: HomeAssistant, config_entry: ConfigEntry, device_entry: DeviceEntry
) -> bool:
    """Detach the MQTT config entry from a device; always allows removal."""
    # Imported here rather than at module level (circular import avoidance)
    # pylint: disable-next=import-outside-toplevel
    from . import device_automation

    removed_device_id = device_entry.id
    await device_automation.async_removed_from_device(hass, removed_device_id)
    return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload the MQTT config entry.

    Removes the publish and dump services, stops discovery, unloads the
    platforms, detaches reload dispatchers and registry hooks, disconnects
    the client and keeps its remaining subscriptions so the next setup of the
    entry can restore them.

    NOTE(review): the results of the platform unloads are not inspected; this
    always returns True.
    """
    mqtt_data = get_mqtt_data(hass)
    assert mqtt_data.client is not None
    mqtt_client = mqtt_data.client

    # Unload publish and dump services.
    hass.services.async_remove(
        DOMAIN,
        SERVICE_PUBLISH,
    )
    hass.services.async_remove(
        DOMAIN,
        SERVICE_DUMP,
    )

    # Stop the discovery
    await discovery.async_stop(hass)
    # Unload the platforms
    await asyncio.gather(
        *(
            hass.config_entries.async_forward_entry_unload(entry, component)
            for component in PLATFORMS
        )
    )
    await hass.async_block_till_done()
    # Unsubscribe reload dispatchers: each stored item is a remover callable
    # (e.g. the entry update listener added in async_setup_entry)
    while reload_dispatchers := mqtt_data.reload_dispatchers:
        reload_dispatchers.pop()()
    # Cleanup listeners
    mqtt_client.cleanup()

    # Trigger reload manual MQTT items at entry setup
    if (mqtt_entry_status := mqtt_config_entry_enabled(hass)) is False:
        # The entry is disabled reload legacy manual items when the entry is enabled again
        mqtt_data.reload_needed = True
    elif mqtt_entry_status is True:
        # The entry is reloaded:
        # Trigger re-fetching the yaml config at entry setup
        mqtt_data.reload_entry = True
    # Reload the legacy yaml platform to make entities unavailable
    await async_reload_integration_platforms(hass, DOMAIN, RELOADABLE_PLATFORMS)
    # Cleanup entity registry hooks; each value is a callable that removes its hook
    registry_hooks = mqtt_data.discovery_registry_hooks
    while registry_hooks:
        registry_hooks.popitem()[1]()
    # Wait for all ACKs and stop the loop
    await mqtt_client.async_disconnect()
    # Store remaining subscriptions to be able to restore or reload them
    # when the entry is set up again
    if mqtt_client.subscriptions:
        mqtt_data.subscriptions_to_restore = mqtt_client.subscriptions

    return True
